{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_pop.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "47d5313e-c29d-4581-a9c7-a45122337069",
      "metadata": {
        "id": "47d5313e-c29d-4581-a9c7-a45122337069"
      },
      "source": [
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"400\">](https://github.com/jeshraghian/snntorch/)\n",
        "\n",
        "# snnTorch - Population Coding in Spiking Neural Nets\n",
        "## By Jason K. Eshraghian (www.jasoneshraghian.com)\n",
        "\n",
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_pop.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width=\"28\">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width=\"80\">](https://github.com/jeshraghian/snntorch/)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "oll2NNFeG1NG",
      "metadata": {
        "id": "oll2NNFeG1NG"
      },
      "source": [
        "For a comprehensive overview of how SNNs work and what is going on under the hood, [see the snnTorch tutorial series available here.](https://snntorch.readthedocs.io/en/latest/tutorials/index.html)\n",
        "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
        "\n",
        "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". Proceedings of the IEEE, 111(9) September 2023.](https://ieeexplore.ieee.org/abstract/document/10242251) </cite>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6w8ThNNEYn5i",
      "metadata": {
        "id": "6w8ThNNEYn5i"
      },
      "source": [
        "# Introduction\n",
        "It is thought that rate codes alone cannot be the dominant encoding mechanism in the primary cortex. One of several reasons is that the average neuronal firing rate is roughly $0.1$ to $1$ Hz, which is far too slow to account for the reaction times of animals and humans.\n",
        "\n",
        "But if we pool multiple neurons and count their spikes together, it becomes possible to measure a firing rate for a population of neurons in a very short window of time. Population coding therefore lends some plausibility to rate-encoding mechanisms.\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial_pop/pop.png?raw=true' width=\"300\">\n",
        "</center>\n",
        "\n",
        "\n",
        "In this tutorial, you will:\n",
        "* Learn how to train a population-coded network. Instead of assigning one neuron per class, we will extend this to multiple neurons per class, and aggregate their spikes together.\n",
        "\n",
        "If running in Google Colab:\n",
        "* You may connect to a GPU by selecting `Runtime` > `Change runtime type` > `Hardware accelerator: GPU`\n",
        "* Next, install the latest PyPI distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "hDnIEHOKB8LD",
      "metadata": {
        "id": "hDnIEHOKB8LD"
      },
      "outputs": [],
      "source": [
        "!pip install snntorch --quiet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "WL487gZW1Agy",
      "metadata": {
        "id": "WL487gZW1Agy"
      },
      "outputs": [],
      "source": [
        "import torch, torch.nn as nn\n",
        "import snntorch as snn"
      ]
    },
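    {
      "cell_type": "markdown",
      "id": "pop-rate-sketch",
      "metadata": {
        "id": "pop-rate-sketch"
      },
      "source": [
        "To make the pooling idea from the introduction concrete, here is a minimal sketch that is separate from the tutorial pipeline. It assumes each neuron fires independently with a fixed probability per time step, a Bernoulli model chosen only for illustration. The names `p_fire`, `num_neurons`, and `window` are hypothetical.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "torch.manual_seed(0)\n",
        "\n",
        "p_fire = 0.1        # assumed per-step firing probability (illustrative)\n",
        "num_neurons = 500   # hypothetical population size\n",
        "window = 5          # short observation window, in time steps\n",
        "\n",
        "# spikes[t, n] = 1 if neuron n fired at step t\n",
        "spikes = (torch.rand(window, num_neurons) < p_fire).float()\n",
        "\n",
        "# Rate estimate from one neuron over the window\n",
        "single_rate = spikes[:, 0].mean()\n",
        "\n",
        "# Rate estimate from the pooled population over the same window\n",
        "population_rate = spikes.mean()\n",
        "\n",
        "print(f'single neuron estimate: {single_rate.item():.3f}')\n",
        "print(f'population estimate:    {population_rate.item():.3f}')\n",
        "```\n",
        "\n",
        "With only 5 steps, the single-neuron estimate can take just six distinct values (0, 0.2, ..., 1), whereas the pooled estimate averages over 2,500 samples."
      ]
    },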
    {
      "cell_type": "markdown",
      "id": "EYf13Gtx1OCj",
      "metadata": {
        "id": "EYf13Gtx1OCj"
      },
      "source": [
        "# Data Loading\n",
        "Define variables for dataloading."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "eo4T5MC21hgD",
      "metadata": {
        "id": "eo4T5MC21hgD"
      },
      "outputs": [],
      "source": [
        "batch_size = 128\n",
        "data_path = '/tmp/data/fmnist'\n",
        "# Prefer CUDA, then Apple MPS, then fall back to the CPU\n",
        "device = torch.device(\n",
        "    \"cuda\" if torch.cuda.is_available()\n",
        "    else \"mps\" if torch.backends.mps.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "myFKqNx11qYS",
      "metadata": {
        "id": "myFKqNx11qYS"
      },
      "source": [
        "Load FashionMNIST dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3GdglZjK04cb",
      "metadata": {
        "id": "3GdglZjK04cb"
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "# Define a transform\n",
        "transform = transforms.Compose([\n",
        "            transforms.Resize((28, 28)),\n",
        "            transforms.Grayscale(),\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((0,), (1,))])\n",
        "\n",
        "fmnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)\n",
        "fmnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)\n",
        "\n",
        "# Create DataLoaders\n",
        "train_loader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True)\n",
        "test_loader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "BtJBOtez11wy",
      "metadata": {
        "id": "BtJBOtez11wy"
      },
      "source": [
        "# Define Network\n",
        "Let's compare the performance of a pair of networks both with and without population coding, and train them for *one single time step.*\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Vx7Y3SH-gKzB",
      "metadata": {
        "id": "Vx7Y3SH-gKzB"
      },
      "outputs": [],
      "source": [
        "from snntorch import surrogate\n",
        "\n",
        "# network parameters\n",
        "num_inputs = 28*28\n",
        "num_hidden = 128\n",
        "num_outputs = 10\n",
        "num_steps = 1\n",
        "\n",
        "# spiking neuron parameters\n",
        "beta = 0.9  # neuron decay rate \n",
        "grad = surrogate.fast_sigmoid()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "DukdBcgFgBFW",
      "metadata": {
        "id": "DukdBcgFgBFW"
      },
      "source": [
        "## Without population coding\n",
        "Let's just use a simple 2-layer dense spiking network."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8ipqtuRsgCmJ",
      "metadata": {
        "id": "8ipqtuRsgCmJ"
      },
      "outputs": [],
      "source": [
        "net = nn.Sequential(nn.Flatten(),\n",
        "                    nn.Linear(num_inputs, num_hidden),\n",
        "                    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),\n",
        "                    nn.Linear(num_hidden, num_outputs),\n",
        "                    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)\n",
        "                    ).to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a_si0NREgDnl",
      "metadata": {
        "id": "a_si0NREgDnl"
      },
      "source": [
        "## With population coding\n",
        "Instead of 10 output neurons corresponding to 10 output classes, we will use 500 output neurons. This means each output class has 50 neurons assigned to it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "JM2thnrc10rD",
      "metadata": {
        "id": "JM2thnrc10rD"
      },
      "outputs": [],
      "source": [
        "pop_outputs = 500\n",
        "\n",
        "net_pop = nn.Sequential(nn.Flatten(),\n",
        "                        nn.Linear(num_inputs, num_hidden),\n",
        "                        snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),\n",
        "                        nn.Linear(num_hidden, pop_outputs),\n",
        "                        snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)\n",
        "                        ).to(device)"
      ]
    },
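    {
      "cell_type": "markdown",
      "id": "pop-aggregate-sketch",
      "metadata": {
        "id": "pop-aggregate-sketch"
      },
      "source": [
        "As a rough sketch of how 500 output spikes can be reduced to 10 class scores, the snippet below assumes the neurons are grouped into contiguous blocks of 50 per class. That grouping is an assumption made for illustration; consult `snntorch.functional` for how snnTorch actually maps neurons to classes. `fake_spk` is a hypothetical stand-in for the network output.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "torch.manual_seed(0)\n",
        "\n",
        "num_classes = 10\n",
        "pop_outputs = 500\n",
        "neurons_per_class = pop_outputs // num_classes  # 50\n",
        "\n",
        "# Hypothetical stand-in for one time step of output spikes: [batch, pop_outputs]\n",
        "fake_spk = (torch.rand(4, pop_outputs) < 0.2).float()\n",
        "\n",
        "# Assumed layout: neurons 0-49 -> class 0, 50-99 -> class 1, and so on\n",
        "class_counts = fake_spk.view(-1, num_classes, neurons_per_class).sum(dim=-1)\n",
        "\n",
        "# Predicted class = population with the most spikes\n",
        "predicted = class_counts.argmax(dim=-1)\n",
        "\n",
        "print(class_counts.shape)  # torch.Size([4, 10])\n",
        "print(predicted)\n",
        "```\n",
        "\n",
        "Each row of `class_counts` holds one pooled spike count per class, so the prediction no longer depends on any single output neuron firing."
      ]
    },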
    {
      "cell_type": "markdown",
      "id": "sIrJnBoz490c",
      "metadata": {
        "id": "sIrJnBoz490c"
      },
      "source": [
        "# Training\n",
        "## Without population coding\n",
        "Define the optimizer and loss function. Here, we use the MSE Count Loss, which counts up the total number of output spikes at the end of the simulation run. \n",
        "\n",
        "The correct class has a target firing probability of 100%, and incorrect classes are set to 0%. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "VocYbtD7Vwp7",
      "metadata": {
        "id": "VocYbtD7Vwp7"
      },
      "outputs": [],
      "source": [
        "import snntorch.functional as SF\n",
        "\n",
        "optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))\n",
        "loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0)"
      ]
    },
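    {
      "cell_type": "markdown",
      "id": "count-mse-sketch",
      "metadata": {
        "id": "count-mse-sketch"
      },
      "source": [
        "The idea behind a spike-count MSE loss can be sketched by hand. This is an illustrative reimplementation under stated assumptions, not snnTorch's source: the target count is taken to be `num_steps * correct_rate` for the true class and `num_steps * incorrect_rate` for every other class, and `count_mse` is a hypothetical helper name.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "def count_mse(spk_rec, targets, correct_rate=1.0, incorrect_rate=0.0):\n",
        "    # spk_rec: [num_steps, batch, num_outputs], targets: [batch]\n",
        "    num_steps, _, num_outputs = spk_rec.shape\n",
        "    spike_count = spk_rec.sum(dim=0)  # [batch, num_outputs]\n",
        "\n",
        "    # Assumed target spike counts for the true class and for all others\n",
        "    on_target = num_steps * correct_rate\n",
        "    off_target = num_steps * incorrect_rate\n",
        "    one_hot = F.one_hot(targets, num_outputs).float()\n",
        "    target_count = one_hot * on_target + (1 - one_hot) * off_target\n",
        "\n",
        "    return F.mse_loss(spike_count, target_count)\n",
        "\n",
        "# One time step, two samples, three output neurons\n",
        "spk = torch.tensor([[[1., 0., 0.],\n",
        "                     [0., 1., 1.]]])\n",
        "labels = torch.tensor([0, 2])\n",
        "print(count_mse(spk, labels))  # tensor(0.1667)\n",
        "```\n",
        "\n",
        "The first sample matches its target exactly; the second has one extra spike on a wrong neuron, giving one unit of squared error averaged over six entries."
      ]
    },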
    {
      "cell_type": "markdown",
      "id": "w935Eglbh6bb",
      "metadata": {
        "id": "w935Eglbh6bb"
      },
      "source": [
        "We will also define a simple test accuracy function that predicts the correct class based on the neuron with the highest spike count. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0XkNIs7sh2uv",
      "metadata": {
        "id": "0XkNIs7sh2uv"
      },
      "outputs": [],
      "source": [
        "from snntorch import utils\n",
        "\n",
        "def test_accuracy(data_loader, net, num_steps, population_code=False, num_classes=None):\n",
        "  # num_steps is kept for API symmetry; one forward pass is one time step\n",
        "  with torch.no_grad():\n",
        "    total = 0\n",
        "    acc = 0\n",
        "    net.eval()\n",
        "\n",
        "    for data, targets in data_loader:\n",
        "      data = data.to(device)\n",
        "      targets = targets.to(device)\n",
        "      utils.reset(net)  # clear hidden states between batches\n",
        "      spk_rec, _ = net(data)  # spk_rec: [batch, num_outputs]\n",
        "      spk_rec = spk_rec.unsqueeze(0)  # add a time dimension: [1, batch, num_outputs]\n",
        "      batch_len = targets.size(0)\n",
        "\n",
        "      # accuracy_rate returns the batch mean, so weight it by the batch size\n",
        "      if population_code:\n",
        "        acc += SF.accuracy_rate(spk_rec, targets, population_code=True, num_classes=num_classes) * batch_len\n",
        "      else:\n",
        "        acc += SF.accuracy_rate(spk_rec, targets) * batch_len\n",
        "\n",
        "      total += batch_len\n",
        "\n",
        "  return acc/total"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "57U_9vy1W05U",
      "metadata": {
        "id": "57U_9vy1W05U"
      },
      "source": [
        "Let's run the training loop. Note that we are only training for $1$ time step. I.e., each neuron only has the opportunity to fire once. As a result, we might not expect the network to perform too well here."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "tQZBPV0pgkx0",
      "metadata": {
        "id": "tQZBPV0pgkx0"
      },
      "outputs": [],
      "source": [
        "from snntorch import backprop\n",
        "\n",
        "num_epochs = 5\n",
        "\n",
        "# training loop\n",
        "for epoch in range(num_epochs):\n",
        "\n",
        "    avg_loss = backprop.BPTT(net, train_loader, num_steps=num_steps,\n",
        "                          optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)\n",
        "    \n",
        "    print(f\"Epoch: {epoch}\")\n",
        "    print(f\"Test set accuracy: {test_accuracy(test_loader, net, num_steps)*100:.3f}%\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "zT_HQpA5XcN0",
      "metadata": {
        "id": "zT_HQpA5XcN0"
      },
      "source": [
        "While there are ways to improve single time-step performance, e.g., by applying the loss to the membrane potential, it is extremely challenging to train a competitive network with rate codes in a single time step."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "vT-IuyY_LauT",
      "metadata": {
        "id": "vT-IuyY_LauT"
      },
      "source": [
        "## With population coding\n",
        "Let's modify the loss function to specify that population coding should be enabled. We must also specify the number of classes. With 500 output neurons and 10 classes, each class is represented by $500 / 10 = 50$ neurons."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "YH-rAAxfLd_T",
      "metadata": {
        "id": "YH-rAAxfLd_T"
      },
      "outputs": [],
      "source": [
        "loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0, population_code=True, num_classes=10)\n",
        "optimizer = torch.optim.Adam(net_pop.parameters(), lr=2e-3, betas=(0.9, 0.999))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "afyJpc5vM7qN",
      "metadata": {
        "id": "afyJpc5vM7qN"
      },
      "outputs": [],
      "source": [
        "num_epochs = 5\n",
        "\n",
        "# training loop\n",
        "for epoch in range(num_epochs):\n",
        "\n",
        "    avg_loss = backprop.BPTT(net_pop, train_loader, num_steps=num_steps,\n",
        "                            optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)\n",
        "\n",
        "    print(f\"Epoch: {epoch}\")\n",
        "    print(f\"Test set accuracy: {test_accuracy(test_loader, net_pop, num_steps, population_code=True, num_classes=10)*100:.3f}%\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "Jfrxe1XWYxE7",
      "metadata": {
        "id": "Jfrxe1XWYxE7"
      },
      "source": [
        "Even though we are only training on one time-step, introducing additional output neurons has immediately enabled better performance."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "-iSGTq0Q3Lcm",
      "metadata": {
        "id": "-iSGTq0Q3Lcm"
      },
      "source": [
        "# Conclusion\n",
        "The performance boost from population coding may start to fade as the number of time steps increases. Even so, adding output neurons may be preferable to adding time steps, as PyTorch is optimized for handling matrix-vector products rather than sequential, step-by-step operations over time.\n",
        "\n",
        "* For a detailed tutorial of spiking neurons, neural nets, encoding, and training using neuromorphic datasets, check out the\n",
        "[snnTorch tutorial series](https://snntorch.readthedocs.io/en/latest/tutorials/index.html).\n",
        "* For more information on the features of snnTorch, check out the [documentation at this link](https://snntorch.readthedocs.io/en/latest/).\n",
        "* If you have ideas, suggestions or would like to find ways to get involved, then [check out the snnTorch GitHub project here.](https://github.com/jeshraghian/snntorch)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "include_colab_link": true,
      "name": "snntorch_population_codes.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
